

#' DID check function for the basic design
#'
#' Runs placebo DID regressions for each lag in \code{option$lag}, attaches
#' bootstrap standard errors, and builds two diagnostic plots.
#'
#' @inheritParams did_std
#' @return A list with \code{est}, a tibble with one row per feasible lag
#'   (columns \code{estimate}, \code{lag}, \code{std.error},
#'   \code{estimate_orig}, \code{std.error_orig}), and \code{plot}, a list
#'   holding the outputs of \code{did_std_plot} and \code{did_sad_plot}.
#' @keywords internal
#' @importFrom stats sd
did_check_std <- function(
  formula, data, id_unit, id_time, is_panel = TRUE, option
) {
  ## input check
  if (isFALSE(is_panel)) id_unit <- NULL
  if (isTRUE(is_panel) && is.null(id_unit)) {
    stop("A variable name should be provided to id_unit.")
  }

  ## prepare formulas
  fm_prep <- did_formula(formula, is_panel)

  ##
  ## handle cluster variable
  ##
  ## `var_cluster`     : cluster column name in the transformed data `dat_did`
  ## `var_cluster_pre` : cluster column name handed to did_panel_data
  var_cluster <- option$id_cluster
  if (is.null(var_cluster) && isTRUE(is_panel)) {
    ## default for panel data: cluster by unit, stored as "id_unit"
    var_cluster <- "id_unit"
    var_cluster_pre <- id_unit
  } else {
    ## a user-supplied cluster (or none) is referred to by its own name;
    ## without this branch `var_cluster_pre` would be undefined below.
    ## NOTE(review): assumes did_panel_data keeps the cluster column under
    ## its original name, as the repeated-cross-section path already does.
    var_cluster_pre <- var_cluster
  }

  ## --------------------------------------------
  ## transform data
  ## --------------------------------------------
  if (isTRUE(is_panel)) {
    dat_did <- did_panel_data(
      fm_prep$var_outcome, fm_prep$var_treat, fm_prep$var_covars,
      var_cluster_pre, id_unit, id_time, data
    )
  } else {
    dat_did <- did_rcs_data(
      fm_prep$var_outcome, fm_prep$var_treat, fm_prep$var_post,
      fm_prep$var_covars, var_cluster, id_time, data
    )
  }

  ## --------------------------------------------
  ## estimate placebo test statistics
  ## --------------------------------------------
  did_placebo_est <- did_std_placebo(fm_prep$fm_did[[1]], dat_did, option$lag)

  ## did_std_placebo takes abs() of the lags and drops infeasible ones; its
  ## names are the lags actually estimated, so they (not option$lag) are used
  ## to label and align the results below.
  lag_names <- names(did_placebo_est$est_std)
  if (length(lag_names) == 0) {
    stop("No feasible lag value in option$lag.", call. = FALSE)
  }

  ## --------------------------------------------
  ## compute std.error via bootstrap
  ## --------------------------------------------
  ## resampling units: cluster ids when clustering, row indices otherwise
  if (is.null(var_cluster)) {
    id_cluster_vec <- seq_len(nrow(dat_did))
  } else {
    id_cluster_vec <- unique(pull(dat_did, !!sym(var_cluster)))
  }

  ## setup worker
  setup_parallel(option$parallel)

  ## use future_lapply to implement the bootstrap in parallel;
  ## a replicate that errors returns NULL and is dropped afterwards
  est_boot <- future_lapply(seq_len(option$n_boot), function(i) {
    tryCatch({
      did_std_placebo_boot(fm_prep, dat_did, id_cluster_vec, var_cluster, is_panel, option$lag)
    }, error = function(e) {
      NULL
    })
  }, future.seed = TRUE)
  est_boot <- est_boot[lengths(est_boot) != 0]

  ## keep only replicates that estimated exactly the same lags, so the
  ## columns of the bootstrap matrices line up with the point estimates
  est_boot <- Filter(
    function(x) identical(names(x$est_std), lag_names), est_boot
  )
  if (length(est_boot) == 0) {
    stop("All bootstrap iterations failed.", call. = FALSE)
  }

  ## --------------------------------------------
  ## summarize results
  ## --------------------------------------------
  ## one row per bootstrap replicate, one column per estimated lag
  est_boot_std  <- do.call(rbind, map(est_boot, ~.x$est_std))
  est_boot_orig <- do.call(rbind, map(est_boot, ~.x$est))

  estimates <- as_tibble(data.frame(
    estimate       = unname(did_placebo_est$est_std),
    lag            = as.numeric(lag_names),
    std.error      = unname(apply(est_boot_std, 2, sd)),
    estimate_orig  = unname(did_placebo_est$est),
    std.error_orig = unname(apply(est_boot_orig, 2, sd))
  ))

  ## --------------------------------------------
  ## generate a DID plot
  ## --------------------------------------------
  p1 <- did_std_plot(dat_did)
  p2 <- did_sad_plot(estimates)

  return(list(est = estimates, plot = list(p1, p2)))
}


#' Run a placebo regression on the pretreatment outcome
#' @keywords internal
#' @param formula A formula generated by \code{did_formula} function.
#' @param data An output from \code{did_panel_data} or \code{did_rcs_data} function.
#' @param lags A vector of lag parameters. Negative values are replaced by
#'   their absolute values, and lags that are not smaller than
#'   \code{abs(min(data$id_time_std))} are dropped as infeasible.
#' @importFrom dplyr %>% mutate filter pull
#' @importFrom stats lm sd
#' @return A list of placebo effects (\code{est}),
#'   and standardized effects (\code{est_std}). Both are named by the lags
#'   actually used, and both are empty when no lag is feasible.
did_std_placebo <- function(formula, data, lags) {
  ## remove all infeasible lag values
  lags <- abs(lags)
  max_lag <- abs(min(data[["id_time_std"]]))
  lags <- lags[lags < max_lag]

  ## run placebo regression.
  ## Base-R subsetting is used on purpose: inside dplyr verbs a data column
  ## named like a local variable (e.g. `lags` or `i`) would shadow it.
  est <- est_std <- rep(NA_real_, length(lags))
  for (i in seq_along(lags)) {
    ## keep period -lag (coded It = 1) and the period just before it (It = 0)
    cutoff <- -lags[i]
    dat_use <- data[data[["id_time_std"]] %in% c(cutoff, cutoff - 1), ]
    dat_use$It <- ifelse(dat_use[["id_time_std"]] >= cutoff, 1, 0)
    fit <- lm(formula, data = dat_use)
    est[i] <- fit$coefficients["Gi:It"]

    ## compute the std version:
    ## normalize by the mean and sd of the control group (Gi == 0) in the
    ## earlier period (It == 0); which() drops rows where Gi is NA
    idx_ctrl <- which(dat_use$It == 0 & dat_use[["Gi"]] == 0)
    ct_outcome <- dat_use[["outcome"]][idx_ctrl]
    dat_use$outcome <- (dat_use[["outcome"]] - mean(ct_outcome, na.rm = TRUE)) /
      sd(ct_outcome, na.rm = TRUE)
    fit_std <- lm(formula, data = dat_use)
    est_std[i] <- fit_std$coefficients["Gi:It"]
  }

  names(est) <- names(est_std) <- lags
  return(list(est = est, est_std = est_std))
}


#' Compute placebo estimates on one bootstrap resample
#'
#' Draws clusters (or rows, when \code{var_cluster} is \code{NULL}) with
#' replacement, rebuilds the DID data object, and re-runs the placebo
#' regressions via \code{did_std_placebo}.
#' @param fm_prep An output from \code{did_formula}.
#' @param dat_did A data object from \code{did_panel_data} or \code{did_rcs_data}.
#' @param id_cluster_vec Values to resample: cluster ids, or row indices
#'   when \code{var_cluster} is \code{NULL}.
#' @param var_cluster Name of the cluster column in \code{dat_did}, or \code{NULL}.
#' @param is_panel Logical; whether \code{dat_did} comes from panel data.
#' @param lag A vector of lag parameters passed to \code{did_std_placebo}.
#' @return The output of \code{did_std_placebo} on the resampled data.
#' @noRd
did_std_placebo_boot <- function(
  fm_prep, dat_did, id_cluster_vec, var_cluster, is_panel, lag
) {
  ## sample index
  id_boot <- sample(id_cluster_vec,
    size = length(id_cluster_vec), replace = TRUE
  )

  ## extract the cluster column once instead of once per draw; `[[` returns
  ## a plain vector for both data.frames and tibbles
  cluster_col <- if (is.null(var_cluster)) NULL else dat_did[[var_cluster]]

  ## create dataset: every draw j is relabelled id_unit = j, so a cluster
  ## drawn twice appears as two distinct units in the resample.
  ## NOTE(review): with a user-supplied cluster, all rows of one drawn
  ## cluster share id_unit = j even if they came from different units --
  ## confirm did_panel_data handles this.
  dat_tmp <- vector("list", length(id_boot))
  for (j in seq_along(id_boot)) {
    if (is.null(var_cluster)) {
      id_tmp <- id_boot[j]
    } else {
      id_tmp <- which(cluster_col == id_boot[j])
    }
    dat_tmp[[j]] <- dat_did[id_tmp, ]
    dat_tmp[[j]]$id_unit <- j
  }

  ## create did_data object
  if (isTRUE(is_panel)) {
    dat_boot <- did_panel_data(
      var_outcome = "outcome", var_treat = 'treatment', fm_prep$var_covars,
      var_cluster, id_unit = "id_unit", id_time = 'id_time',
      data = do.call(rbind, dat_tmp)
    )
  } else {
    dat_boot <- did_rcs_data(
      var_outcome = "outcome", var_treat = "Gi", var_post = "It",
      fm_prep$var_covars, var_cluster,
      id_time = "id_time", do.call(rbind, dat_tmp)
    )
  }

  ## fit DID and sDID
  est <- did_std_placebo(fm_prep$fm_did[[1]], dat_boot, lag)
  return(est)
}




#' Create a did plot for standard design
#'
#' Summarises the outcome by relative time and treatment group and plots the
#' group means over time, with a dashed line at time 0.
#' @keywords internal
#' @param data A data object from \code{did_panel_data} or \code{did_rcs_data}
#' @return A list with \code{plot}, a ggplot object, and \code{dat_plot}, the
#'   summarised data (one row per relative time and group) used to draw it.
#' @importFrom ggplot2 ggplot geom_line geom_point aes geom_vline labs theme_bw scale_color_manual
#' @importFrom dplyr %>% across group_by summarise mutate ungroup select
#' @importFrom stats qnorm
did_std_plot <- function(data) {
  ## tidyselect arguments (across(), select()) take plain column names;
  ## `.data$` inside selections is deprecated since tidyselect 1.2.0
  dat_plot <- data %>%
    group_by(.data$id_time_std, .data$Gi) %>%
    summarise(
      across("outcome", list(mean = mean, sd = sd)),
      .groups = "drop"
    ) %>%
    mutate(group = ifelse(.data$Gi == 1, "Treated", "Control")) %>%
    select("group", time_to_treat = "id_time_std",
           "outcome_mean", std.error = "outcome_sd") %>%
    ## NOTE(review): `std.error` is the within-group sd of the outcome, not
    ## the standard error of the group mean; the CI90 columns inherit this.
    mutate(CI90_UB = .data$outcome_mean + qnorm(0.95) * .data$std.error,
           CI90_LB = .data$outcome_mean - qnorm(0.95) * .data$std.error)

  gg <- ggplot(dat_plot, aes(x = .data$time_to_treat, y = .data$outcome_mean, color = .data$group)) +
    geom_vline(xintercept = 0, linetype = 'dashed') + geom_line() + geom_point() +
    labs(x = "Time relative to treatment assignment", y = "Mean Outcome", color = "Group") +
    scale_color_manual(values = c("gray50", '#1E88A8')) +
    theme_bw()
  return(list(plot = gg, dat_plot = dat_plot))
}
